/*
 * UTF - 8
 * 在802.11be中开启OFDMA;
 * 实际上OFDMA是802.11ax中的技术，只是因为最新的标准是be，所以就这么做了
 * 网络拓扑图如下：仅使用了一个channel，比MLD的要简单很多
 * ap * --- x GHz, y MHz --- * sta1
 *                           * sta2
 *                           * sta3
 *                           * sta4
 * 如果是downlink，即ap向sta发送数据的情况，在pcap文件中可以看到同一时间内ap向不同的ip地址发送数据帧；
 * 也可以看到不同的sta在接收到ack触发帧后在同一时间回复ack
 * 
 * ns-3对OFDMA的实现有两个方面，一是AP同时多发，sta同时ack，二是在ap的协调下sta同时上传数据。
 */
#include "ns3/boolean.h"
#include "ns3/callback.h"
#include "ns3/command-line.h"
#include "ns3/config.h"
#include "ns3/double.h"
#include "ns3/eht-phy.h"
#include "ns3/enum.h"
#include "ns3/frame-exchange-manager.h"
#include "ns3/internet-stack-helper.h"
#include "ns3/ipv4-address-generator.h"
#include "ns3/ipv4-address-helper.h"
#include "ns3/ipv4.h"
#include "ns3/log.h"
#include "ns3/mobility-helper.h"
#include "ns3/multi-model-spectrum-channel.h"
#include "ns3/on-off-helper.h"
#include "ns3/packet-sink-helper.h"
#include "ns3/packet-sink.h"
#include "ns3/rng-seed-manager.h"
#include "ns3/spectrum-wifi-helper.h"
#include "ns3/ssid.h"
#include "ns3/string.h"
#include "ns3/udp-client-server-helper.h"
#include "ns3/uinteger.h"
#include "ns3/wifi-acknowledgment.h"
#include "ns3/wifi-mac.h"
#include "ns3/wifi-net-device.h"
#include "ns3/wifi-phy-state-helper.h"
#include "ns3/wifi-phy.h"

#include <algorithm>
#include <array>
#include <chrono>
#include <cstddef>
#include <cstdint>
#include <ctime>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <numeric>
#include <sstream>
#include <string>
#include <vector>

using namespace ns3;
/// Base name shared by this script's output files (the ".cc.log" text log and the pcap traces).
const std::string thisFileName("multiStaOFDMA");

// NOTE(review): this component name does not match the file name; kept as-is so existing
// NS_LOG settings keep working.
NS_LOG_COMPONENT_DEFINE("EhtStaApLinkLink");

/**
 * @brief Appends a time-stamped message to the log file "<thisFileName>.cc.log".
 *
 * The time stamp is the host wall-clock time formatted as HH-MM-SS.mmm (not the simulation
 * time). The file is opened in append mode on every call and closed again on return, so each
 * message reaches the file immediately. If the file cannot be opened, an error message is
 * printed to standard output instead.
 *
 * @param logStr The text to write. It is written verbatim, so the caller supplies any
 *               trailing newline.
 */
void
writeLog(const std::string& logStr)
{
    const auto currentTime = std::chrono::system_clock::now();
    const std::time_t currentTime_t = std::chrono::system_clock::to_time_t(currentTime);
    const auto milliseconds =
        std::chrono::duration_cast<std::chrono::milliseconds>(currentTime.time_since_epoch())
            .count() %
        1000;

    // std::localtime may return nullptr; keep a placeholder rather than dereferencing it.
    char buffer[80] = "??-??-??";
    if (const std::tm* currentTime_tm = std::localtime(&currentTime_t))
    {
        std::strftime(buffer, sizeof(buffer), "%H-%M-%S", currentTime_tm);
    }
    std::ostringstream ss;
    ss << buffer << "." << std::setfill('0') << std::setw(3) << milliseconds;
    const std::string formattedTime = ss.str();

    // The stream closes itself when it goes out of scope (RAII).
    std::ofstream outputFile(thisFileName + ".cc.log", std::ios::app);
    if (outputFile.is_open())
    {
        outputFile << "[" << formattedTime << "] " << logStr;
    }
    else
    {
        std::cout << "[" << formattedTime << "]"
                  << "Failed to open or create the file!" << std::endl;
    }
}

void
WifiStateTrace(std::string context, Time start, Time duration, WifiPhyState state)
{
    std::stringstream nowStr;
    Time now = Simulator::Now();
    nowStr << "[" << now.GetNanoSeconds() << "] ";
    std::stringstream startStr;
    startStr << "Start [" << start.GetNanoSeconds() << "] ";
    std::stringstream durationStr;
    durationStr << "Duration [" << duration.GetNanoSeconds() << "] ";
    std::string stateStr;
    switch (state)
    {
    case IDLE:
        stateStr = "IDLE";
        break;
    case CCA_BUSY:
        stateStr = "CCA_BUSY";
        break;
    case TX:
        stateStr = "TX";
        break;
    case RX:
        stateStr = "RX";
        break;
    case SWITCHING:
        stateStr = "SWITCHING";
        break;
    case SLEEP:
        stateStr = "SLEEP";
        break;
    case OFF:
        stateStr = "OFF";
        break;
    default:
        NS_FATAL_ERROR("Invalid state");
        stateStr = "INVALID";
        break;
    }
    writeLog("[" + context + "] " + nowStr.str() + startStr.str() + durationStr.str() + stateStr +
             "\n");
}

/**
 * @brief Connects WifiStateTrace() to the "State" trace source of every PHY of a device.
 *
 * The context passed to the sink for the i-th PHY (in the order returned by GetPhys()) is
 * `context` followed by i, e.g. "ap0". The simulation is aborted if the device is null or a
 * trace connection fails, instead of silently producing an empty log.
 *
 * @param wifi_dev The WifiNetDevice whose PHYs are traced; must not be null.
 * @param context The prefix of the trace context string.
 */
void
ConnectContextTrace2Mac(Ptr<WifiNetDevice> wifi_dev, std::string context)
{
    NS_ABORT_MSG_IF(!wifi_dev, "ConnectContextTrace2Mac(): null WifiNetDevice");

    std::size_t phyId = 0;
    for (const auto& wifiPhy : wifi_dev->GetPhys())
    {
        const bool connected = wifiPhy->GetState()->TraceConnect("State",
                                                                 context + std::to_string(phyId),
                                                                 MakeCallback(&WifiStateTrace));
        NS_ABORT_MSG_UNLESS(connected,
                            "ConnectContextTrace2Mac(): failed to connect the State trace of PHY "
                                << phyId);
        ++phyId;
    }
}

/**
 * \param udp true if UDP is used, false if TCP is used
 * \param serverApp a container of server applications
 * \param payloadSize the size in bytes of the packets
 * \return the bytes received by each server application
 */
std::vector<uint64_t>
GetRxBytes(bool udp, const ApplicationContainer& serverApp, uint32_t payloadSize)
{
    std::vector<uint64_t> rxBytes(serverApp.GetN(), 0);
    if (udp)
    {
        for (uint32_t i = 0; i < serverApp.GetN(); i++)
        {
            rxBytes[i] = payloadSize * DynamicCast<UdpServer>(serverApp.Get(i))->GetReceived();
        }
    }
    else
    {
        for (uint32_t i = 0; i < serverApp.GetN(); i++)
        {
            rxBytes[i] = DynamicCast<PacketSink>(serverApp.Get(i))->GetTotalRx();
        }
    }
    return rxBytes;
};

/**
 * @brief Prints information about the WiFi network devices.
 *
 * This function takes a NetDeviceContainer and a context string as input.
 * It prints information about the MAC address and IP address of each WiFi network device in the
 * container.
 *
 * @param netDevices The NetDeviceContainer containing the WiFi network devices.
 * @param context The context string to be printed along with the information.
 */
void
wifiNetDeviceInfo(const NetDeviceContainer netDevices, std::string context)
{
    std::cout << context << " info : " << std::endl;
    for (std::size_t i = 0; i < netDevices.GetN(); i++)
    {
        std::cout << std::endl;
        // mac address
        const Ptr<WifiNetDevice> wifiNetDevice = DynamicCast<WifiNetDevice>(netDevices.Get(i));
        Ptr<WifiMac> wifiMac = wifiNetDevice->GetMac();
        std::cout << "\t"
                  << "mac address : " << wifiMac->GetAddress() << std::endl;
        for (std::size_t linkId = 0; linkId < std::max<uint8_t>(wifiNetDevice->GetNPhys(), 1);
             linkId++)
        {
            auto fem = wifiMac->GetFrameExchangeManager(linkId);
            std::cout << "\t\t"
                      << "linkId " << linkId << " mac Address : " << fem->GetAddress() << std::endl;
        }

        // ip address
        std::cout << "\t"
                  << "ip address : " << wifiMac->GetAddress() << std::endl;
        Ptr<Ipv4> ipv4 = (wifiNetDevice->GetNode())->GetObject<Ipv4>();
        NS_ASSERT_MSG(ipv4,
                      "Ipv4AddressHelper::Assign(): NetDevice is associated"
                      " with a node without IPv4 stack installed -> fail "
                      "(maybe need to use InternetStackHelper?)");
        for (std::size_t interfaceNum = 0; interfaceNum < ipv4->GetNInterfaces(); interfaceNum++)
        {
            std::cout << "\t\t"
                      << "interface " << interfaceNum << " : " << std::endl;
            for (std::size_t addressNum = 0; addressNum < ipv4->GetNAddresses(interfaceNum);
                 addressNum++)
            {
                Ipv4InterfaceAddress ipv4InterfaceAddress =
                    ipv4->GetAddress(interfaceNum, addressNum);
                std::cout << "\t\t\t"
                          << "ip address " << addressNum << " : " << ipv4InterfaceAddress.GetLocal()
                          << std::endl;
            }
        }
    }
}

int
main(int argc, char* argv[])
{
    bool useRts{false}; // enable the RTS/CTS mechanism
    int mcs{0};         // EHT MCS index used for data frames
    // Channel width in MHz. Note: if the channel is configured too narrow, the MU features of
    // OFDMA cannot be triggered.
    int channelWidth{40};
    int gi{3200};                       // guard interval in ns; only 800, 1600 or 3200 are valid
    bool udp{true};                     // use UDP traffic if true, TCP otherwise
    uint32_t payloadSize{700};          // application payload size in bytes
    std::size_t nStations{4};           // number of STAs
    std::string dlAckSeqType{"MU-BAR"}; // data transmission / acknowledgment sequence for DL
    // Carrier band in GHz. MLD features are not considered in this OFDMA study, so only a single
    // band (one link) is used.
    double freq{5};
    uint16_t channelSwitchDelayUsec{100}; // PHY channel switch delay, in microseconds
    bool enableUlOfdma{false};            // enable UL OFDMA
    bool enableBsrp{false};               // enable BSRP for UL OFDMA
    // Interval between two channel access requests of the MU scheduler: even without DL data,
    // the scheduler accesses the channel periodically at this interval to coordinate UL MU
    // transmissions of the STAs.
    Time accessReqInterval{0};
    // BlockAck buffer size in MPDUs; at most 256 in ns-3.40. In the ns-3 development version of
    // 2023.1.3 it can be up to 1024 and is configured through WifiMac instead, but pcap output is
    // not written correctly there, so that version is not used for these simulations for now.
    uint16_t mpduBufferSize{256};
    double distance{1};        // STA-to-AP distance in meters; co-located STAs are not a problem
    bool downlink{true};       // true: DL traffic only; false: UL traffic only
    double simulationTime{10}; // simulation time in seconds

    CommandLine cmd(__FILE__);
    cmd.AddValue("useRts", "Enable RTS/CTS", useRts);
    cmd.AddValue("mcs", "MCS value, 0 ~ 11", mcs);
    cmd.AddValue("channelWidth",
                 "Channel width, 40 MHz is better for OFDMA, 20, 40, 80, 160",
                 channelWidth);
    cmd.AddValue("gi", "Guard interval, 800ns, 1600ns, 3200ns", gi);
    cmd.AddValue("udp", "true for use UDP, false for use TCP", udp);
    cmd.AddValue("payloadSize", "udp or tcp Payload size", payloadSize);
    cmd.AddValue("nStations", "Number of stations", nStations);
    cmd.AddValue("dlAckType",
                 "Ack sequence type for DL OFDMA (NO-OFDMA, ACK-SU-FORMAT, MU-BAR, AGGR-MU-BAR)",
                 dlAckSeqType);
    cmd.AddValue("freq", "Whether the link operates in the 2.4, 5 or 6 GHz band", freq);
    cmd.AddValue("channelSwitchDelay",
                 "The PHY channel switch delay in microseconds",
                 channelSwitchDelayUsec);
    cmd.AddValue("enableUlOfdma",
                 "Enable UL OFDMA (useful if DL OFDMA is enabled and TCP is used)",
                 enableUlOfdma);
    cmd.AddValue("enableBsrp",
                 "Enable BSRP (useful if DL and UL OFDMA are enabled and TCP is used)",
                 enableBsrp);
    cmd.AddValue(
        "muSchedAccessReqInterval",
        "Duration of the interval between two requests for channel access made by the MU scheduler",
        accessReqInterval);
    cmd.AddValue("mpduBufferSize",
                 "Size (in number of MPDUs) of the BlockAck buffer, must less than 256",
                 mpduBufferSize);
    cmd.AddValue("distance",
                 "Distance in meters between the station and the access point",
                 distance);
    cmd.AddValue("downlink",
                 "Generate downlink flows if set to true, uplink flows otherwise",
                 downlink);
    cmd.AddValue("simulationTime", "Simulation time in seconds", simulationTime);
    cmd.Parse(argc, argv);

    if (useRts)
    {
        Config::SetDefault("ns3::WifiRemoteStationManager::RtsCtsThreshold", StringValue("0"));
        Config::SetDefault("ns3::WifiDefaultProtectionManager::EnableMuRts", BooleanValue(true));
    }

    if (dlAckSeqType == "ACK-SU-FORMAT")
    {
        Config::SetDefault("ns3::WifiDefaultAckManager::DlMuAckSequenceType",
                           EnumValue(WifiAcknowledgment::DL_MU_BAR_BA_SEQUENCE));
    }
    else if (dlAckSeqType == "MU-BAR")
    {
        Config::SetDefault("ns3::WifiDefaultAckManager::DlMuAckSequenceType",
                           EnumValue(WifiAcknowledgment::DL_MU_TF_MU_BAR));
    }
    else if (dlAckSeqType == "AGGR-MU-BAR")
    {
        Config::SetDefault("ns3::WifiDefaultAckManager::DlMuAckSequenceType",
                           EnumValue(WifiAcknowledgment::DL_MU_AGGREGATE_TF));
    }
    else if (dlAckSeqType != "NO-OFDMA")
    {
        NS_ABORT_MSG("Invalid DL ack sequence type (must be NO-OFDMA, ACK-SU-FORMAT, MU-BAR or "
                     "AGGR-MU-BAR)");
    }

    if (!udp)
    {
        Config::SetDefault("ns3::TcpSocket::SegmentSize", UintegerValue(payloadSize));
    }

    NodeContainer wifiStaNodes;
    wifiStaNodes.Create(nStations);
    NodeContainer wifiApNode;
    wifiApNode.Create(1);

    WifiHelper wifi;
    wifi.SetStandard(WIFI_STANDARD_80211be);

    std::array<std::string, 3> channelStr;
    std::array<FrequencyRange, 3> freqRanges;
    uint8_t nLinks = 0;
    std::string dataModeStr = "EhtMcs" + std::to_string(mcs);
    std::string ctrlRateStr;
    uint64_t nonHtRefRateMbps = EhtPhy::GetNonHtReferenceRate(mcs) / 1e6;
    channelStr[nLinks] = "{0, " + std::to_string(channelWidth) + ", ";

    if (freq == 6)
    {
        channelStr[nLinks] += "BAND_6GHZ, 0}";
        freqRanges[nLinks] = WIFI_SPECTRUM_6_GHZ;
        Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss", DoubleValue(48));
        wifi.SetRemoteStationManager(nLinks,
                                     "ns3::ConstantRateWifiManager",
                                     "DataMode",
                                     StringValue(dataModeStr),
                                     "ControlMode",
                                     StringValue(dataModeStr));
    }
    else if (freq == 5)
    {
        channelStr[nLinks] += "BAND_5GHZ, 0}";
        freqRanges[nLinks] = WIFI_SPECTRUM_5_GHZ;
        ctrlRateStr = "OfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
        wifi.SetRemoteStationManager(nLinks,
                                     "ns3::ConstantRateWifiManager",
                                     "DataMode",
                                     StringValue(dataModeStr),
                                     "ControlMode",
                                     StringValue(ctrlRateStr));
    }
    else if (freq == 2.4)
    {
        channelStr[nLinks] += "BAND_2_4GHZ, 0}";
        freqRanges[nLinks] = WIFI_SPECTRUM_2_4_GHZ;
        Config::SetDefault("ns3::LogDistancePropagationLossModel::ReferenceLoss", DoubleValue(40));
        ctrlRateStr = "ErpOfdmRate" + std::to_string(nonHtRefRateMbps) + "Mbps";
        wifi.SetRemoteStationManager(nLinks,
                                     "ns3::ConstantRateWifiManager",
                                     "DataMode",
                                     StringValue(dataModeStr),
                                     "ControlMode",
                                     StringValue(ctrlRateStr));
    }
    else
    {
        std::cout << "Wrong frequency value!" << std::endl;
        return 1; // invalid configuration: report failure to the caller
    }
    nLinks++;

    // EMLSR is not enabled here, so no EMLSR parameters are configured on the MAC.

    Ssid ssid = Ssid("ns3-802.11be");

    SpectrumWifiPhyHelper phy(nLinks);
    phy.SetPcapDataLinkType(WifiPhyHelper::DLT_IEEE802_11_RADIO);

    phy.Set("ChannelSwitchDelay", TimeValue(MicroSeconds(channelSwitchDelayUsec)));

    WifiMacHelper mac;
    mac.SetType("ns3::StaWifiMac", "Ssid", SsidValue(ssid));
    for (uint8_t linkId = 0; linkId < nLinks; linkId++)
    {
        phy.Set(linkId, "ChannelSettings", StringValue(channelStr[linkId]));

        auto spectrumChannel = CreateObject<MultiModelSpectrumChannel>();
        auto lossModel = CreateObject<LogDistancePropagationLossModel>();
        spectrumChannel->AddPropagationLossModel(lossModel);
        phy.AddChannel(spectrumChannel, freqRanges[linkId]);
    }
    NetDeviceContainer staDevices;
    staDevices = wifi.Install(phy, mac, wifiStaNodes);

    // Whether OFDMA is used is decided by the AP (through its multi-user scheduler); the STAs
    // can only follow the AP's lead.
    if (dlAckSeqType != "NO-OFDMA")
    {
        mac.SetMultiUserScheduler("ns3::RrMultiUserScheduler",
                                  "EnableUlOfdma",
                                  BooleanValue(enableUlOfdma),
                                  "EnableBsrp",
                                  BooleanValue(enableBsrp),
                                  "AccessReqInterval",
                                  TimeValue(accessReqInterval));
    }
    mac.SetType("ns3::ApWifiMac",
                "EnableBeaconJitter",
                BooleanValue(false),
                "Ssid",
                SsidValue(ssid));
    NetDeviceContainer apDevice;
    apDevice = wifi.Install(phy, mac, wifiApNode);

    // Fix the RNG seed and run number so that every run produces the same results.
    RngSeedManager::SetSeed(1);
    RngSeedManager::SetRun(1);
    int64_t streamNumber = 100;
    streamNumber += wifi.AssignStreams(apDevice, streamNumber);
    streamNumber += wifi.AssignStreams(staDevices, streamNumber);

    // Set guard interval and MPDU buffer size
    Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/HeConfiguration/GuardInterval",
                TimeValue(NanoSeconds(gi)));
    Config::Set("/NodeList/*/DeviceList/*/$ns3::WifiNetDevice/HeConfiguration/MpduBufferSize",
                UintegerValue(mpduBufferSize));

    // Mobility: the AP at the origin and every STA `distance` meters away on the x-axis.
    // ListPositionAllocator cycles through its list, so one position is added per node;
    // with only two entries, every other STA would be placed on top of the AP.
    MobilityHelper mobility;
    Ptr<ListPositionAllocator> positionAlloc = CreateObject<ListPositionAllocator>();

    positionAlloc->Add(Vector(0.0, 0.0, 0.0)); // AP (installed first below)
    for (std::size_t i = 0; i < nStations; i++)
    {
        positionAlloc->Add(Vector(distance, 0.0, 0.0));
    }
    mobility.SetPositionAllocator(positionAlloc);

    mobility.SetMobilityModel("ns3::ConstantPositionMobilityModel");

    mobility.Install(wifiApNode);
    mobility.Install(wifiStaNodes);

    /* Internet stack*/
    InternetStackHelper stack;
    stack.Install(wifiApNode);
    stack.Install(wifiStaNodes);

    Ipv4AddressHelper address;
    address.SetBase("192.168.1.0", "255.255.255.0");
    Ipv4InterfaceContainer staNodeInterfaces;
    Ipv4InterfaceContainer apNodeInterface;

    staNodeInterfaces = address.Assign(staDevices);
    apNodeInterface = address.Assign(apDevice);

    /* Setting applications */
    // Note: traffic is saturated here, so enabling OFDMA can actually lower the total
    // network-layer throughput because the extra control frames occupy channel time.

    ApplicationContainer serverApp;
    auto serverNodes = downlink ? std::ref(wifiStaNodes) : std::ref(wifiApNode);
    Ipv4InterfaceContainer serverInterfaces;
    NodeContainer clientNodes;

    for (std::size_t i = 0; i < nStations; i++)
    {
        serverInterfaces.Add(downlink ? staNodeInterfaces.Get(i) : apNodeInterface.Get(0));
        clientNodes.Add(downlink ? wifiApNode.Get(0) : wifiStaNodes.Get(i));
    }

    if (udp)
    {
        // UDP flow
        uint16_t port = 9;
        UdpServerHelper server(port);
        serverApp = server.Install(serverNodes.get());
        serverApp.Start(Seconds(0.0));
        serverApp.Stop(Seconds(simulationTime + 1));

        for (std::size_t i = 0; i < nStations; i++)
        {
            UdpClientHelper client(serverInterfaces.GetAddress(i), port);
            client.SetAttribute("MaxPackets", UintegerValue(4294967295U));
            // inter-packet interval of 10 us, i.e. 100000 packets/s
            client.SetAttribute("Interval", TimeValue(Time("0.00001")));
            client.SetAttribute("PacketSize", UintegerValue(payloadSize));
            ApplicationContainer clientApp = client.Install(clientNodes.Get(i));
            clientApp.Start(Seconds(1.0));
            clientApp.Stop(Seconds(simulationTime + 1));
        }
    }
    else
    {
        // TCP flow
        uint16_t port = 50000;
        Address localAddress(InetSocketAddress(Ipv4Address::GetAny(), port));
        PacketSinkHelper packetSinkHelper("ns3::TcpSocketFactory", localAddress);
        serverApp = packetSinkHelper.Install(serverNodes.get());
        serverApp.Start(Seconds(0.0));
        serverApp.Stop(Seconds(simulationTime + 1));

        for (std::size_t i = 0; i < nStations; i++)
        {
            OnOffHelper onoff("ns3::TcpSocketFactory", Ipv4Address::GetAny());
            onoff.SetAttribute("OnTime", StringValue("ns3::ConstantRandomVariable[Constant=1]"));
            onoff.SetAttribute("OffTime", StringValue("ns3::ConstantRandomVariable[Constant=0]"));
            onoff.SetAttribute("PacketSize", UintegerValue(payloadSize));
            onoff.SetAttribute("DataRate", DataRateValue(1000000000)); // bit/s
            AddressValue remoteAddress(InetSocketAddress(serverInterfaces.GetAddress(i), port));
            onoff.SetAttribute("Remote", remoteAddress);
            ApplicationContainer clientApp = onoff.Install(clientNodes.Get(i));
            clientApp.Start(Seconds(1.0));
            clientApp.Stop(Seconds(simulationTime + 1));
        }
    }

    // Print some information about the STAs and the AP.
    wifiNetDeviceInfo(staDevices, "sta");
    wifiNetDeviceInfo(apDevice, "ap");

    ConnectContextTrace2Mac(DynamicCast<WifiNetDevice>(apDevice.Get(0)), "ap");

    std::string pcapFileName = thisFileName + "-ap";
    phy.EnablePcap(pcapFileName, apDevice, true);
    std::string pcapFileName1 = thisFileName + "-sta";
    phy.EnablePcap(pcapFileName1, staDevices, true);

    std::cout << "MCS value"
              << "\t\t"
              << "Channel width"
              << "\t\t"
              << "GI"
              << "\t\t\t"
              << "Throughput" << '\n';

    Simulator::Stop(Seconds(simulationTime + 1));
    Simulator::Run();

    // Bytes received by each server application; the server type depends on the transport
    // protocol actually used (UdpServer for UDP, PacketSink for TCP).
    const std::vector<uint64_t> cumulRxBytes = GetRxBytes(udp, serverApp, payloadSize);
    // The initial value must be uint64_t: an int literal would make accumulate sum in int and
    // truncate large byte counts.
    const uint64_t rxBytes =
        std::accumulate(cumulRxBytes.cbegin(), cumulRxBytes.cend(), uint64_t{0});
    const double throughput = (rxBytes * 8) / (simulationTime * 1000000.0); // Mbit/s

    Simulator::Destroy();

    std::cout << mcs << "\t\t\t" << channelWidth << " MHz\t\t\t" << gi << " ns\t\t\t" << throughput
              << " Mbit/s" << std::endl;
}